package com.os.core.service.extend.impl;

import com.os.common.constant.TableMatchConstant;
import com.os.common.entity.extend.ColumnInfo;
import com.os.common.entity.extend.TableMatch;
import com.os.common.entity.table.TableCreateSQL;
import com.os.common.utils.FileUtil;
import com.os.common.utils.MyUtil;
import com.os.core.datasource.DBChangeService;
import com.os.core.mapper.BaseInfoMapper;
import com.os.core.mapper.ExtendMapper;
import com.os.core.service.extend.TableMatchService;
import org.springframework.stereotype.Service;

import javax.servlet.http.HttpServletResponse;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.Set;

/**
 * Table-match service implementation.
 *
 * <p>Compares the tables and columns of a source datasource against a target datasource and
 * produces a MySQL-style DDL script (backtick-quoted identifiers, {@code COMMENT} clauses) that
 * aligns the target's structure with the source.
 *
 * @author huxuehao
 **/
@Service
public class TableMatchServiceImpl implements TableMatchService {
    private final DBChangeService dbChangeService;
    private final ExtendMapper extendMapper;
    private final BaseInfoMapper baseInfoMapper;

    public TableMatchServiceImpl(DBChangeService dbChangeService, ExtendMapper extendMapper, BaseInfoMapper baseInfoMapper) {
        this.dbChangeService = dbChangeService;
        this.extendMapper = extendMapper;
        this.baseInfoMapper = baseInfoMapper;
    }

    /**
     * Generates the matching script and writes it, UTF-8 encoded, to the response body.
     *
     * <p>The script is built before the download headers are applied and the output stream is
     * opened, so a failure while reading metadata propagates before anything is written.
     *
     * @param tableMatch match options (source/target datasource ids and per-operation options)
     * @param response   response that receives the script as a download
     * @throws Exception if script generation or writing to the response fails
     */
    @Override
    public void matchReturnFile(TableMatch tableMatch, HttpServletResponse response) throws Exception {
        byte[] content = matchReturnScript(tableMatch).getBytes(StandardCharsets.UTF_8);
        FileUtil.setDownloadHeader(response);
        // try-with-resources guarantees the stream is closed even if write/flush fails,
        // and close failures are propagated instead of being printed and swallowed.
        try (OutputStream os = response.getOutputStream()) {
            os.write(content);
            os.flush();
        }
    }

    /**
     * Builds the DDL statements that bring the target datasource in line with the source.
     *
     * <p>Statements are emitted in this order, each group skipped when its option equals
     * {@link TableMatchConstant#IGNORE}: DROP TABLE (tables only in target), CREATE TABLE
     * (tables only in source), DROP column, then ADD column (both for tables present in both).
     * Column DROPs deliberately precede column ADDs so that a column dropped and re-added under
     * the same name is recreated rather than rejected as a duplicate.
     *
     * <p>The current datasource is switched through {@link DBChangeService} while reading
     * metadata and is always switched back to the default one before returning, including when
     * an exception is thrown.
     *
     * @param tableMatch match options
     * @return the statements joined with {@code "\n"}; empty if nothing differs
     */
    @Override
    public String matchReturnScript(TableMatch tableMatch) {
        String sourceDatasourceId = tableMatch.getSourceDS();
        String targetDatasourceId = tableMatch.getTargetDS();
        try {
            List<String> sourceTableNames = getTableNames(sourceDatasourceId);
            List<String> targetTableNames = getTableNames(targetDatasourceId);
            // Sets give O(1) membership checks; iteration order still follows the original lists.
            Set<String> sourceTableSet = new HashSet<>(sourceTableNames);
            Set<String> targetTableSet = new HashSet<>(targetTableNames);

            List<String> sqlList = new ArrayList<>();

            /* Tables that exist only in the target */
            if (isEnabled(tableMatch.getMoreTableOpt())) {
                sqlList.addAll(dropTable(sourceTableSet, targetTableNames));
            }

            /* Tables that are missing from the target */
            if (isEnabled(tableMatch.getLessTableOpt())) {
                sqlList.addAll(addTable(sourceDatasourceId, sourceTableNames, targetTableSet));
            }

            /* Column differences for tables present in both datasources */
            boolean dropColumns = isEnabled(tableMatch.getMoreColOpt());
            boolean addColumns = isEnabled(tableMatch.getLessColOpt());
            if (dropColumns || addColumns) {
                List<String> dropColumnSql = new ArrayList<>();
                List<String> addColumnSql = new ArrayList<>();
                for (String tableName : targetTableNames) {
                    if (!sourceTableSet.contains(tableName)) {
                        continue;
                    }
                    // Fetch each side's columns once and reuse them for both diffs.
                    List<ColumnInfo> sourceColumns = getColumnInfo(sourceDatasourceId, tableName);
                    List<ColumnInfo> targetColumns = getColumnInfo(targetDatasourceId, tableName);
                    if (dropColumns) {
                        dropColumnSql.addAll(dropColumn(tableName, sourceColumns, targetColumns));
                    }
                    if (addColumns) {
                        addColumnSql.addAll(addColumn(tableName, sourceColumns, targetColumns));
                    }
                }
                sqlList.addAll(dropColumnSql);
                sqlList.addAll(addColumnSql);
            }

            return String.join("\n", sqlList);
        } finally {
            /* Switch back to the default datasource */
            dbChangeService.changeDefaultBD();
        }
    }

    /* Returns true unless the option equals TableMatchConstant.IGNORE. */
    private static boolean isEnabled(Object option) {
        return !TableMatchConstant.IGNORE.equals(option);
    }

    /* Switches to the given datasource and reads its table names; never returns null. */
    private List<String> getTableNames(String datasourceId) {
        dbChangeService.changeBD(datasourceId);
        List<String> names = extendMapper.getTableNames();
        if (names == null) {
            return Collections.emptyList();
        }
        return names;
    }

    /* Switches to the given datasource and reads the columns of one table; never returns null. */
    private List<ColumnInfo> getColumnInfo(String datasourceId, String tableName) {
        dbChangeService.changeBD(datasourceId);
        List<ColumnInfo> columns = extendMapper.getColumnInfo(tableName);
        if (columns == null) {
            return Collections.emptyList();
        }
        return columns;
    }

    /* Builds CREATE TABLE statements for source tables that are missing from the target. */
    private List<String> addTable(String sourceDatasourceId,
                                  List<String> sourceTableNames,
                                  Set<String> targetTableSet) {
        dbChangeService.changeBD(sourceDatasourceId);
        List<String> list = new ArrayList<>();
        for (String sourceTableName : sourceTableNames) {
            if (!targetTableSet.contains(sourceTableName)) {
                TableCreateSQL createSql = baseInfoMapper.getCreateSql(sourceTableName, "BASE TABLE");
                if (createSql == null) {
                    // Fail loudly instead of with a bare NPE, and never emit a partial script silently.
                    throw new IllegalStateException("Could not obtain CREATE statement for table " + sourceTableName);
                }
                list.add(createSql.getCreateTableSQL() + ";");
            }
        }
        return list;
    }

    /* Builds DROP TABLE statements for target tables that do not exist in the source. */
    private List<String> dropTable(Set<String> sourceTableSet, List<String> targetTableNames) {
        List<String> list = new ArrayList<>();
        for (String targetTableName : targetTableNames) {
            if (!sourceTableSet.contains(targetTableName)) {
                list.add("DROP TABLE " + quoteIdentifier(targetTableName) + ";");
            }
        }
        return list;
    }

    /*
     * Builds ADD column statements for source columns that are missing from the target table.
     * NOTE(review): membership relies on ColumnInfo.equals, which is not visible here; if it
     * compares more than the name, a changed column is also dropped by dropColumn and re-added here.
     * NOTE(review): defaults are always emitted as quoted literals, so expression defaults such as
     * CURRENT_TIMESTAMP are quoted too; confirm whether that is acceptable.
     */
    private List<String> addColumn(String tableName,
                                   List<ColumnInfo> sourceColumns,
                                   List<ColumnInfo> targetColumns) {
        List<String> list = new ArrayList<>();
        for (ColumnInfo sourceColumn : sourceColumns) {
            if (targetColumns.contains(sourceColumn)) {
                continue;
            }
            StringBuilder sql = new StringBuilder("ALTER TABLE ")
                    .append(quoteIdentifier(tableName))
                    .append(" ADD ")
                    .append(quoteIdentifier(sourceColumn.getName()))
                    .append(' ')
                    .append(sourceColumn.getType());
            if (!"YES".equalsIgnoreCase(sourceColumn.getNullAble())) {
                sql.append(" NOT NULL");
            }
            if (sourceColumn.getDefaultVal() != null) {
                sql.append(" DEFAULT ").append(quoteLiteral(sourceColumn.getDefaultVal()));
            }
            // A null comment previously produced the literal text COMMENT 'null'.
            if (sourceColumn.getComment() != null) {
                sql.append(" COMMENT ").append(quoteLiteral(sourceColumn.getComment()));
            }
            sql.append(';');
            list.add(sql.toString());
        }
        return list;
    }

    /* Builds DROP column statements for target columns that do not exist in the source table. */
    private List<String> dropColumn(String tableName,
                                    List<ColumnInfo> sourceColumns,
                                    List<ColumnInfo> targetColumns) {
        List<String> list = new ArrayList<>();
        for (ColumnInfo targetColumn : targetColumns) {
            if (!sourceColumns.contains(targetColumn)) {
                list.add("ALTER TABLE " + quoteIdentifier(tableName) + " DROP " + quoteIdentifier(targetColumn.getName()) + ";");
            }
        }
        return list;
    }

    /* Wraps a value in backticks, doubling embedded backticks (MySQL identifier quoting). */
    private static String quoteIdentifier(Object name) {
        return "`" + String.valueOf(name).replace("`", "``") + "`";
    }

    /*
     * Wraps a value in single quotes, escaping backslashes and single quotes.
     * Assumes the default MySQL sql_mode (backslash is an escape character); with
     * NO_BACKSLASH_ESCAPES enabled the backslash doubling would need to be removed.
     */
    private static String quoteLiteral(Object value) {
        String text = String.valueOf(value).replace("\\", "\\\\").replace("'", "''");
        return "'" + text + "'";
    }
}
